{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# torch.optim\n",
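    "torch.optim is a package that implements various optimization algorithms. Most commonly used methods are already supported, and the interface is general enough that more sophisticated methods can be integrated in the future.\n",
    "\n",
    "## How to use an optimizer\n",
    "To use torch.optim, you have to construct an optimizer object. This object holds the current parameter state and updates the parameters based on the computed gradients."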
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Constructing it\n",
    "To construct an Optimizer, you have to give it an iterable containing the parameters to optimize. You can then specify optimizer-specific options such as the learning rate, weight decay, and so on.\n",
    "```python\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)\n",
    "optimizer = optim.Adam([var1, var2], lr=0.0001)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
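    "### Per-parameter options\n",
    "Optimizers also support specifying per-parameter options. To do this, pass an iterable of dicts instead of an iterable of parameters. Each dict defines a separate parameter group and must contain a params key holding the list of parameters that belong to it. The other keys should match the keyword arguments accepted by the optimizer and are used as the optimization options for that group.\n",
    "\n",
    "**Note:**\n",
    "You can still pass options as keyword arguments. They are used as defaults in the groups that do not override them. This is useful when you want to vary a single option while keeping all the others consistent between parameter groups.  \n",
    "For example, this is very useful when you want to specify per-layer learning rates:\n",
    "```python\n",
    "optim.SGD([\n",
    "    {'params': model.base.parameters()},\n",
    "    {'params': model.classifier.parameters(), 'lr': 1e-3}\n",
    "], lr=1e-2, momentum=0.9)\n",
    "```\n",
    "This means that the parameters of model.base will use a learning rate of 1e-2, the parameters of model.classifier will use a learning rate of 1e-3, and a momentum of 0.9 will be used for all parameters."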
   ]
  },
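  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The resolved options of each group can be inspected through optimizer.param_groups. The following is a minimal sketch; the Net class and its layer sizes are hypothetical and exist only to provide base and classifier submodules.\n",
    "```python\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "\n",
    "# Hypothetical two-part model used only for illustration.\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.base = nn.Linear(4, 8)\n",
    "        self.classifier = nn.Linear(8, 2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.classifier(torch.relu(self.base(x)))\n",
    "\n",
    "model = Net()\n",
    "optimizer = optim.SGD(\n",
    "    [\n",
    "        {'params': model.base.parameters()},                    # falls back to lr=1e-2\n",
    "        {'params': model.classifier.parameters(), 'lr': 1e-3},  # overrides lr\n",
    "    ],\n",
    "    lr=1e-2,\n",
    "    momentum=0.9,\n",
    ")\n",
    "\n",
    "# Each entry of param_groups is a dict holding the resolved options.\n",
    "group_lrs = [group['lr'] for group in optimizer.param_groups]\n",
    "group_momenta = [group['momentum'] for group in optimizer.param_groups]\n",
    "print(group_lrs, group_momenta)\n",
    "```\n",
    "The first group inherits the keyword default for lr, the second keeps its own value, and both share the momentum."
   ]
  },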
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Taking an optimization step\n",
    "All optimizers implement a step() method that updates the parameters. It can be used in two ways:\n",
    "\n",
    "**optimizer.step()**\n",
    "\n",
    "This is a simplified version supported by most optimizers. The function can be called once the gradients have been computed, for example by backward().\n",
    "\n",
    "Example:\n",
    "```python\n",
    "for input, target in dataset:\n",
    "    optimizer.zero_grad()\n",
    "    output = model(input)\n",
    "    loss = loss_fn(output, target)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "```\n",
    "\n",
    "**optimizer.step(closure)**\n",
    "\n",
    "Some optimization algorithms, such as Conjugate Gradient and LBFGS, need to reevaluate the function multiple times, so you have to pass in a closure that allows them to recompute your model. The closure should clear the gradients, compute the loss, and return it.\n",
    "\n",
    "Example:\n",
    "```python\n",
    "for input, target in dataset:\n",
    "    def closure():\n",
    "        optimizer.zero_grad()\n",
    "        output = model(input)\n",
    "        loss = loss_fn(output, target)\n",
    "        loss.backward()\n",
    "        return loss\n",
    "    optimizer.step(closure)\n",
    "```"
   ]
  },
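  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the closure form concrete, here is a small self-contained sketch using LBFGS. The one-dimensional objective (w - 3)^2 and the variable names are assumptions chosen for illustration, not part of the original text.\n",
    "```python\n",
    "import torch\n",
    "from torch import optim\n",
    "\n",
    "# Toy problem (an assumption for illustration): minimise (w - 3)^2.\n",
    "w = torch.zeros(1, requires_grad=True)\n",
    "optimizer = optim.LBFGS([w], lr=1, max_iter=20)\n",
    "\n",
    "def closure():\n",
    "    optimizer.zero_grad()          # clear old gradients\n",
    "    loss = ((w - 3.0) ** 2).sum()  # recompute the loss\n",
    "    loss.backward()                # recompute the gradients\n",
    "    return loss                    # step() needs the loss value\n",
    "\n",
    "for _ in range(5):\n",
    "    optimizer.step(closure)        # may call closure() several times\n",
    "\n",
    "final_w = w.item()\n",
    "print(final_w)\n",
    "```\n",
    "Because step() decides how often the objective is evaluated, all of the work that produces the loss and its gradients has to live inside the closure."
   ]
  },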
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
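    "## Functions\n",
    "### torch.optim.Optimizer\n",
    "class torch.optim.Optimizer(params, defaults)  \n",
    "Base class for all optimizers.  \n",
    "Parameters:  \n",
    "- params (iterable) – an iterable of tensors or dicts. Specifies which parameters should be optimized. \n",
    "- defaults (dict) – a dict containing default values of the optimization options (used when a parameter group does not specify them)."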
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### load_state_dict\n",
    "load_state_dict(state_dict)  \n",
    "Loads the optimizer state.  \n",
    "Parameters:  \n",
    "- state_dict (dict) – the optimizer state. It should be an object returned from a call to state_dict()."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### state_dict\n",
    "state_dict()  \n",
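    "Returns the state of the optimizer as a dict. It contains two entries:  \n",
    "- state - a dict holding the current optimization state. Its content differs between optimizer classes. \n",
    "- param_groups - a list containing all parameter groups, where each group is a dict."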
   ]
  },
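  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A short sketch of how state_dict() and load_state_dict() fit together. The model, the learning rates, and the variable names are assumptions made for illustration.\n",
    "```python\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "\n",
    "model = nn.Linear(2, 1)\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)\n",
    "\n",
    "# One update so that the momentum buffers exist in the state.\n",
    "loss = model(torch.ones(4, 2)).sum()\n",
    "optimizer.zero_grad()\n",
    "loss.backward()\n",
    "optimizer.step()\n",
    "\n",
    "saved = optimizer.state_dict()\n",
    "top_level_keys = sorted(saved.keys())\n",
    "\n",
    "# A fresh optimizer over the same parameters, created with a different lr.\n",
    "restored = optim.SGD(model.parameters(), lr=0.5, momentum=0.9)\n",
    "restored.load_state_dict(saved)\n",
    "restored_lr = restored.param_groups[0]['lr']\n",
    "print(top_level_keys, restored_lr)\n",
    "```\n",
    "Loading restores the saved group options as well as the per-parameter state, so the learning rate given to the new constructor is replaced by the saved one."
   ]
  },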
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### step\n",
    "step(closure)  \n",
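    "Performs a single optimization step (parameter update).  \n",
    "Parameters:  \n",
    "- closure (callable) – a closure that reevaluates the model and returns the loss. Optional for most optimizers."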
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### zero_grad\n",
    "zero_grad()  \n",
    "Clears the gradients of all optimized parameters."
   ]
  },
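  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reason zero_grad() is called in every iteration is that backward() adds to existing gradients instead of overwriting them. A minimal sketch, with a single hypothetical parameter w:\n",
    "```python\n",
    "import torch\n",
    "from torch import optim\n",
    "\n",
    "w = torch.tensor([1.0], requires_grad=True)\n",
    "optimizer = optim.SGD([w], lr=0.1)\n",
    "\n",
    "(2.0 * w).sum().backward()\n",
    "(2.0 * w).sum().backward()       # second backward adds to the first\n",
    "accumulated = w.grad.item()      # 2 + 2\n",
    "\n",
    "optimizer.zero_grad()\n",
    "# Depending on the PyTorch version, zero_grad() either fills .grad with\n",
    "# zeros or resets it to None; either way the stale gradient is gone.\n",
    "cleared = w.grad is None or w.grad.abs().sum().item() == 0.0\n",
    "print(accumulated, cleared)\n",
    "```\n",
    "Skipping zero_grad() therefore makes each step use the sum of all gradients computed so far, which is rarely what a training loop intends."
   ]
  },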
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
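    "## Algorithms\n",
    "### Adadelta\n",
    "class torch.optim.Adadelta(params, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0) - Implements the Adadelta algorithm.  \n",
    "Parameters:  \n",
    "- params (iterable) – iterable of parameters to optimize or dicts defining parameter groups \n",
    "- rho (float, optional) – coefficient used for computing a running average of squared gradients (default: 0.9) \n",
    "- eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-6) \n",
    "- lr (float, optional) – coefficient that scales delta before it is applied to the parameters (default: 1.0) \n",
    "- weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)"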
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Adagrad\n",
    "class torch.optim.Adagrad(params, lr=0.01, lr_decay=0, weight_decay=0) - Implements the Adagrad algorithm.  \n",
    "Parameters:  \n",
    "- params (iterable) – iterable of parameters to optimize or dicts defining parameter groups \n",
    "- lr (float, optional) – learning rate (default: 1e-2) \n",
    "- lr_decay (float, optional) – learning rate decay (default: 0) \n",
    "- weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Adamax\n",
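    "class torch.optim.Adamax(params, lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) - Implements the Adamax algorithm (a variant of Adam based on the infinity norm).  \n",
    "Parameters:  \n",
    "- params (iterable) – iterable of parameters to optimize or dicts defining parameter groups \n",
    "- lr (float, optional) – learning rate (default: 2e-3) \n",
    "- betas (Tuple[float, float], optional) – coefficients used for computing running averages of the gradient and its square (default: (0.9, 0.999)) \n",
    "- eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8) \n",
    "- weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)"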
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Adam\n",
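    "class torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) - Implements the Adam algorithm.  \n",
    "Parameters:  \n",
    "- params (iterable) – iterable of parameters to optimize or dicts defining parameter groups \n",
    "- lr (float, optional) – learning rate (default: 1e-3) \n",
    "- betas (Tuple[float, float], optional) – coefficients used for computing running averages of the gradient and its square (default: (0.9, 0.999)) \n",
    "- eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8) \n",
    "- weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)"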
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ASGD\n",
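    "class torch.optim.ASGD(params, lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_decay=0) - Implements Averaged Stochastic Gradient Descent.  \n",
    "Parameters:  \n",
    "- params (iterable) – iterable of parameters to optimize or dicts defining parameter groups \n",
    "- lr (float, optional) – learning rate (default: 1e-2) \n",
    "- lambd (float, optional) – decay term (default: 1e-4) \n",
    "- alpha (float, optional) – power for the eta update (default: 0.75) \n",
    "- t0 (float, optional) – point at which to start averaging (default: 1e6) \n",
    "- weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)"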
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LBFGS\n",
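    "class torch.optim.LBFGS(params, lr=1, max_iter=20, max_eval=None, tolerance_grad=1e-05, tolerance_change=1e-09, history_size=100, line_search_fn=None) - Implements the L-BFGS algorithm.  \n",
    "**Warning:**  \n",
    "This optimizer does not support per-parameter options or parameter groups (there can be only one).  \n",
    "**Warning:**  \n",
    "Right now all parameters have to be on a single device. This will be improved in the future.  \n",
    "**Note:**  \n",
    "This is a very memory-intensive optimizer (it requires an additional param_bytes * (history_size + 1) bytes). If it does not fit in memory, try reducing the history size or use a different algorithm.  \n",
    "Parameters:  \n",
    "- lr (float) – learning rate (default: 1) \n",
    "- max_iter (int) – maximal number of iterations per optimization step (default: 20) \n",
    "- max_eval (int) – maximal number of function evaluations per optimization step (default: max_iter * 1.25) \n",
    "- tolerance_grad (float) – termination tolerance on first-order optimality (default: 1e-5) \n",
    "- tolerance_change (float) – termination tolerance on changes of the function value or the parameters (default: 1e-9) \n",
    "- history_size (int) – update history size (default: 100)"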
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### RMSprop\n",
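    "class torch.optim.RMSprop(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False) - Implements the RMSprop algorithm.  \n",
    "Parameters:  \n",
    "- params (iterable) – iterable of parameters to optimize or dicts defining parameter groups \n",
    "- lr (float, optional) – learning rate (default: 1e-2) \n",
    "- momentum (float, optional) – momentum factor (default: 0) \n",
    "- alpha (float, optional) – smoothing constant (default: 0.99) \n",
    "- eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8) \n",
    "- centered (bool, optional) – if True, compute the centered RMSProp, in which the gradient is normalized by an estimate of its variance \n",
    "- weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)"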
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Rprop\n",
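    "class torch.optim.Rprop(params, lr=0.01, etas=(0.5, 1.2), step_sizes=(1e-06, 50)) - Implements the resilient backpropagation algorithm.  \n",
    "Parameters:  \n",
    "- params (iterable) – iterable of parameters to optimize or dicts defining parameter groups \n",
    "- lr (float, optional) – learning rate (default: 1e-2) \n",
    "- etas (Tuple[float, float], optional) – pair of (etaminus, etaplus), the multiplicative decrease and increase factors (default: (0.5, 1.2)) \n",
    "- step_sizes (Tuple[float, float], optional) – pair of minimal and maximal allowed step sizes (default: (1e-6, 50))"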
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SGD\n",
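    "class torch.optim.SGD(params, lr, momentum=0, dampening=0, weight_decay=0, nesterov=False) - Implements stochastic gradient descent (optionally with momentum).  \n",
    "Parameters:  \n",
    "- params (iterable) – iterable of parameters to optimize or dicts defining parameter groups \n",
    "- lr (float) – learning rate (required) \n",
    "- momentum (float, optional) – momentum factor (default: 0) \n",
    "- weight_decay (float, optional) – weight decay (L2 penalty) (default: 0) \n",
    "- dampening (float, optional) – dampening for momentum (default: 0) \n",
    "- nesterov (bool, optional) – enables Nesterov momentum (default: False)"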
   ]
  },
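  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate what the momentum option does, the sketch below compares optim.SGD with a hand-written update on a toy objective w^2. It assumes dampening=0, weight_decay=0 and nesterov=False, in which case PyTorch keeps a velocity v and applies v <- momentum * v + grad followed by w <- w - lr * v. The names w_ref and v are hypothetical.\n",
    "```python\n",
    "import torch\n",
    "from torch import optim\n",
    "\n",
    "lr, mu = 0.1, 0.9\n",
    "w = torch.tensor([1.0], requires_grad=True)\n",
    "optimizer = optim.SGD([w], lr=lr, momentum=mu)\n",
    "\n",
    "# Hand-written reference: v <- mu * v + g, then w <- w - lr * v.\n",
    "w_ref, v = 1.0, 0.0\n",
    "for _ in range(3):\n",
    "    optimizer.zero_grad()\n",
    "    loss = (w ** 2).sum()   # gradient is 2 * w\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    g = 2.0 * w_ref\n",
    "    v = mu * v + g\n",
    "    w_ref = w_ref - lr * v\n",
    "\n",
    "print(w.item(), w_ref)\n",
    "```\n",
    "Both values should agree up to floating-point precision. Note that this formulation multiplies the whole velocity by the learning rate, which differs slightly from some textbook presentations of momentum."
   ]
  },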
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
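    "Those are all of the optimization algorithms. You can try each of them yourself: simply replace the optimizer = optim.Adam(...) line in the example code with another algorithm."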
   ]
  },
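  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to try several algorithms side by side is to pass the optimizer constructor into a small training helper. This is only a sketch: the synthetic data, the train helper, the candidates dict and the chosen learning rates are all assumptions for illustration and are not tuned.\n",
    "```python\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "\n",
    "# Small synthetic data set (an assumption for illustration): y = 2x + 1.\n",
    "x = torch.linspace(0, 1, 20).unsqueeze(1)\n",
    "y = 2 * x + 1\n",
    "\n",
    "def train(make_optimizer, steps=200):\n",
    "    torch.manual_seed(0)                 # same initial weights for each run\n",
    "    model = nn.Linear(1, 1)\n",
    "    criterion = nn.MSELoss()\n",
    "    optimizer = make_optimizer(model.parameters())\n",
    "    first = criterion(model(x), y).item()\n",
    "    for _ in range(steps):\n",
    "        optimizer.zero_grad()\n",
    "        loss = criterion(model(x), y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    return first, criterion(model(x), y).item()\n",
    "\n",
    "candidates = {\n",
    "    'SGD': lambda params: optim.SGD(params, lr=0.1),\n",
    "    'Adam': lambda params: optim.Adam(params, lr=0.05),\n",
    "    'RMSprop': lambda params: optim.RMSprop(params, lr=0.01),\n",
    "}\n",
    "results = {name: train(make) for name, make in candidates.items()}\n",
    "for name, (first, last) in results.items():\n",
    "    print(name, first, last)\n",
    "```\n",
    "Since only the constructor changes, the training loop itself stays identical for every algorithm, which is the point of the shared Optimizer interface."
   ]
  },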
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch[200/10000], loss: 20.570246\n",
      "Epoch[400/10000], loss: 19.275066\n",
      "Epoch[600/10000], loss: 18.040579\n",
      "Epoch[800/10000], loss: 16.864079\n",
      "Epoch[1000/10000], loss: 15.743072\n",
      "Epoch[1200/10000], loss: 14.675252\n",
      "Epoch[1400/10000], loss: 13.658504\n",
      "Epoch[1600/10000], loss: 12.690866\n",
      "Epoch[1800/10000], loss: 11.770534\n",
      "Epoch[2000/10000], loss: 10.895849\n",
      "Epoch[2200/10000], loss: 10.065274\n",
      "Epoch[2400/10000], loss: 9.277388\n",
      "Epoch[2600/10000], loss: 8.530876\n",
      "Epoch[2800/10000], loss: 7.824512\n",
      "Epoch[3000/10000], loss: 7.157146\n",
      "Epoch[3200/10000], loss: 6.527702\n",
      "Epoch[3400/10000], loss: 5.935160\n",
      "Epoch[3600/10000], loss: 5.378541\n",
      "Epoch[3800/10000], loss: 4.856909\n",
      "Epoch[4000/10000], loss: 4.369348\n",
      "Epoch[4200/10000], loss: 3.914962\n",
      "Epoch[4400/10000], loss: 3.492862\n",
      "Epoch[4600/10000], loss: 3.102158\n",
      "Epoch[4800/10000], loss: 2.741949\n",
      "Epoch[5000/10000], loss: 2.411316\n",
      "Epoch[5200/10000], loss: 2.109313\n",
      "Epoch[5400/10000], loss: 1.834961\n",
      "Epoch[5600/10000], loss: 1.587235\n",
      "Epoch[5800/10000], loss: 1.365064\n",
      "Epoch[6000/10000], loss: 1.167317\n",
      "Epoch[6200/10000], loss: 0.992804\n",
      "Epoch[6400/10000], loss: 0.840263\n",
      "Epoch[6600/10000], loss: 0.708362\n",
      "Epoch[6800/10000], loss: 0.595696\n",
      "Epoch[7000/10000], loss: 0.500787\n",
      "Epoch[7200/10000], loss: 0.422087\n",
      "Epoch[7400/10000], loss: 0.357992\n",
      "Epoch[7600/10000], loss: 0.306850\n",
      "Epoch[7800/10000], loss: 0.266990\n",
      "Epoch[8000/10000], loss: 0.236745\n",
      "Epoch[8200/10000], loss: 0.214486\n",
      "Epoch[8400/10000], loss: 0.198666\n",
      "Epoch[8600/10000], loss: 0.187861\n",
      "Epoch[8800/10000], loss: 0.180804\n",
      "Epoch[9000/10000], loss: 0.176422\n",
      "Epoch[9200/10000], loss: 0.173849\n",
      "Epoch[9400/10000], loss: 0.172426\n",
      "Epoch[9600/10000], loss: 0.171687\n",
      "Epoch[9800/10000], loss: 0.171322\n",
      "Epoch[10000/10000], loss: 0.171146\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3Xt4VNW9xvHvSgjECIpcLAiEAUQBuQQIWoQqchMBkSIqp6mneGxR64UeEQWDotggFo9KH2+Nl4LHtD0UBfECRUUEFJUEQbkVRAJGkJtyiQEMZJ0/JgyZYUImyUz2npn38zw8k71mZ/aPgbyzsvbaaxtrLSIiElsSnC5ARETCT+EuIhKDFO4iIjFI4S4iEoMU7iIiMUjhLiISgxTuIiIxSOEuIhKDFO4iIjGollMHbtSokfV4PE4dXkQkKuXl5e211jauaD/Hwt3j8ZCbm+vU4UVEopIxZlso+2lYRkQkBincRURikMJdRCQGOTbmHkxxcTEFBQUcOXLE6VIESE5Opnnz5iQlJTldiohUkqvCvaCggHr16uHxeDDGOF1OXLPWsm/fPgoKCmjVqpXT5YhIJblqWObIkSM0bNhQwe4CxhgaNmyo36JEopSrwh1QsLuI/i1Eopfrwl1EJFYdKT7OE+9uYsf+wxE/lsI9QEFBAddccw1t27alTZs2jB07lp9++inovjt27GDkyJEVvubgwYPZv39/lep56KGHePzxxyvcr27duqd9fv/+/Tz77LNVqkFEqm927je0e2Ahf35/M0s37Yn48aI73HNywOOBhATvY05OtV7OWsuIESMYPnw4mzdvZtOmTRQWFpKZmXnKvseOHeO8885jzpw5Fb7uO++8Q/369atVW3Up3EWcceBwMZ4Jb3PvnC8AGJ52HqMuTo34caM33HNyYMwY2LYNrPU+jhlTrYBfvHgxycnJ3HTTTQAkJiby5JNP8vLLL1NUVMTMmTO57rrruPrqqxk4cCD5+fl07NgRgKKiIq6//no6d+7MDTfcwCWXXOJbXsHj8bB3717y8/Np3749v/vd77jooosYOHAghw97fz174YUX6NGjB126dOHaa6+lqKjotLVu3bqVnj170qNHDx544AFfe2FhIf369aNbt2506tSJN954A4AJEyawZcsW0tLSGD9+fLn7iUj4PP/hFro8vMi3vXT8FTw1qmuNHDt6wz0zEwIDsKjI215F69ato3v37n5tZ511FqmpqXz11VcArFixglmzZrF48WK//Z599lnOOeccvvjiCx544AHy8vKCHmPz5s3cfvvtrFu3jvr16/Paa68BMGLECFauXMmaNWto3749L7300mlrHTt2LLfddhsrV66kSZMmvvbk5GTmzp3LqlWr+OCDDxg3bhzWWqZNm0abNm1YvXo106dPL3c/Eam+3QeP4JnwNtMWbATglstakz9tCKkNU2qsBlfNc6+U7dsr1x4Ca23QGSJl2wcMGECDBg1O2Wf58uWMHTsWgI4dO9K5c+egx2jVqhVpaWkAdO/enfz8fADWrl3LpEmT2L9/P4WFhVx55ZWnrfWjjz7yfTDceOON3Hfffb5a77//fpYuXUpCQgLffvstu3btCvp3CrZf2Q8KEam8R95az0vLt/q2V2b2p3G9OjVeR/SGe2qqdygmWHsVXXTRRb7APOHgwYN88803tGnThry8PM4888yg3xtqr7dOnZP/yImJib5hmdGjRzNv3jy6dOnCzJkzWbJkSYWvFeyDKCcnhz179pCXl0dSUhIejyfoXPVQ9xOR0OTv/ZE+jy/xbWcObs/vLmvtWD0VDssYY5KNMZ8ZY9YYY9YZYx4Oss9oY8weY8zq0j+/jUy5ZWRlQUrArzgpKd72KurXrx9FRUW88sorABw/fpxx48YxevRoUgKPFaB3797Mnj0bgPXr1/Pll19W6tiHDh2iadOmFBcXkxPCeYNevXrx
j3/8A8Bv/wMHDnDuueeSlJTEBx98wLbSD8B69epx6NChCvcTkcq78++f+wX7Fw8NdDTYIbQx96NAX2ttFyANGGSM+XmQ/f7PWptW+ufFsFYZTEYGZGdDy5ZgjPcxO9vbXkXGGObOncs///lP2rZtywUXXEBycjJTp06t8Ht///vfs2fPHjp37sxjjz1G586dOfvss0M+9iOPPMIll1zCgAEDaNeuXYX7z5gxg2eeeYYePXpw4MABX3tGRga5ubmkp6eTk5Pje62GDRvSq1cvOnbsyPjx48vdT0RCt/bbA3gmvM2ba3YA8Ph1XcifNoSzkp1fj8lU5iSaMSYFWA7cZq39tEz7aCDdWntHqK+Vnp5uA2/WsWHDBtq3bx9yPW5y/PhxiouLSU5OZsuWLfTr149NmzZRu3Ztp0urlmj+NxGJlJISy6jsT/gs/3sAzklJYsXEfiQnJUb82MaYPGttekX7hTRbxhiTaIxZDewG3i0b7GVca4z5whgzxxjTopL1Rr2ioiJ69+5Nly5d+OUvf8lzzz0X9cEuIqf6eMteWt//ji/YXx6dzucPDgwt2MN8bc7phHRC1Vp7HEgzxtQH5hpjOlpr15bZ5U3g79bao8aYW4FZQN/A1zHGjAHGAKRW48SnG9WrV0+3DRSJYcXHS+j/xIds2+edgt2uST3evusXJCaEuAbTiWtzTkzhPnFtDlRrOLk8lZrnbq3dDywBBgW077PWHi3dfAHoThDW2mxrbbq1Nr1x4wrv7yoi4goL1+6kbeYCX7DPubUnC/9wWejBDhG5Nud0Kuy5G2MaA8XW2v3GmDOA/sBjAfs0tdbuLN0cBmwIe6UiIjXs8E/H6frIIo4UlwBw2QWNmXVTj6qtmBqBa3NOJ5RhmabALGNMIt6e/mxr7VvGmClArrV2PnCXMWYYcAz4HhgdkWpFRGrI3z7dzv1zT05p/tcfLuPCJvWq/oIRuDbndCoMd2vtF8ApiyFYax8s8/VEYGJ4SxMRqXn7i34ibcq7vu3rujdn+nVdqv/CWVn+Y+5Q7WtzTid615aJkMTERNLS0nx/8vPzyc3N5a677gJgyZIlfPzxx779582bx/r1633bDz74IO+9915Yajmx4FhZ8+fPZ9q0aWF5fRHx9/TizX7BvuzeK8IT7BCRa3NOJ3qXH4iQM844g9WrV/u1eTwe0tO900qXLFlC3bp1ufTSSwFvuA8dOpQOHToAMGXKlIjWN2zYMIYNGxbRY4jEm+8OHOHnj77v2779ijaMvzICF/ZlZEQszAOp5x6CJUuWMHToUPLz83n++ed58sknSUtL48MPP2T+/PmMHz+etLQ0tmzZwujRo31rvHs8HiZPnuxbVnfjRu8KcXv27GHAgAF069aNW265hZYtW57SQy/PzJkzueMO77Vio0eP5q677uLSSy+ldevWfmvLT58+nR49etC5c2cmT54c5ndEJHZMfmOtX7DnTeofmWCvYa7tuT/85jrW7zgY1tfscN5ZTL76otPuc/jwYd+qja1atWLu3Lm+5zweD7feeit169blnnvuAbw96aFDh5Z7R6ZGjRqxatUqnn32WR5//HFefPFFHn74Yfr27cvEiRNZuHAh2dnZVf477dy5k+XLl7Nx40aGDRvGyJEjWbRoEZs3b+azzz7DWsuwYcNYunQpl112WZWPIxKVcnK8Uw23b/eeuMzK8vWct+wppN//fOjb9cGhHfiv3q2cqjTsXBvuTgk2LFMdI0aMALzL+77++uuAd3ngEx8agwYN4pxzzqny6w8fPpyEhAQ6dOjgW9p30aJFLFq0iK5dvefBCwsL2bx5s8Jd4ks5Fw1ZC7fZdixc951v17UPX0ndOrEVh67921TUw44WJ5b4TUxM5NixY0DoywNX5vXLvq61lokTJ3LLLbeE7TgiUSfIRUNfnHUew9bWB7zBPmNUGtekNXOguMjTmHslBS6dG7gdirLLAy9atIgffvghrDVeeeWVvPzyyxQWFgLw
7bffsnv37rAeQ8T1ylwcVIJh+I2PM+w3TwFwbr06/PuPg2I22EHhXmlXX301c+fOJS0tjWXLljFq1CimT59O165d2bJlS0ivMXnyZBYtWkS3bt1YsGABTZs2pV694BdHdO7cmebNm9O8eXPuvvvukF5/4MCB/OpXv6Jnz5506tSJkSNHVvoDSCTqlV4c9LcuV9L6vjdZfZ73JOnMJc/wWWZ/6tSK/AqOTqrUkr/hFGtL/lbG0aNHSUxMpFatWqxYsYLbbrstrOP84RQv/yYSe4r+N4cO6+r7tjvt3My81yaR+Je/1Nh0xEgIdclf1465x7Lt27dz/fXXU1JSQu3atXnhhRecLkkkpvw+J493ygT7Q+/9hdF7v4AoD/bKULg7oG3btnz++edOlyESc/YWHiX9j/5XiG99dDBm2hCHKnKO68LdWlu1Fdck7JwashOpikFPLWXjdyfPLT2X0Y2rOjV1sCJnuSrck5OT2bdvHw0bNlTAO8xay759+0hOTna6FJHT+npPIX3LXIwEkB+HPfVArgr35s2bU1BQwJ49e5wuRfB+2DZv3tzpMkTK5Znwtt/2a7f1pHvLBg5V4y6uCvekpCRatYqdy39FJDLytn3Ptc+t8GtTb92fq8JdRKQigb3198ddTpvGdR2qxr0U7iISFRau3cmtr67ybbc9ty7v3n25gxW5m65QFamMnBzweCAhwfuYk+N0RTHPWotnwtt+wb4ys7+CvQLquYuEqpxVBoG4uTCmpv31o608/ObJO51d1bEJz/26u4MVRQ9XLT8g4moeT/AbHLdsCfn5NV1NTCs+XkLbzAV+beunXElKbfVHtfyASLiVWWUwpHapkilvruflj7b6tm+9vA0Tror+OyPVNIW7SKhSU4P33EtXH5TqKTx6jI6T/+XX9lXWVdRK1KnBqtC7JhKqrCxISfFvS0nxtku13DxzpV+wPzK8I/nThijYq0E9d5FQnThpWs49OaXydh88wsVT3/dr2/roYC0/EgYKd5HKyMhQmIfJ5dM/YNu+k7fBe/E/0+nf4WcOVhRbFO4iUqM27zrEgCeX+rVp6YDwU7iLSI0JXDpg3u29SGtRv5y9pToU7iIScZ98vY9R2Z/4tuvUSuDff7zKwYpin8JdRCIqsLf+4fg+tGx4pkPVxA+Fu4hExJtrdnDn30/eTrJTs7N5887eDlYUXxTuIhJW1lpaTXzHr23VAwNocGZthyqKTwp3EQmbv3y4hUcXbPRtD087j6dGdXWwovhVYbgbY5KBpUCd0v3nWGsnB+xTB3gF6A7sA26w1uaHvVoRcaWfjpVwwST/hb42PjKI5KREhyqSUHruR4G+1tpCY0wSsNwYs8Ba+0mZfW4GfrDWnm+MGQU8BtwQgXpFxGUmzfuSVz85uXjaXf3acveACxysSCCEcLfeNYELSzeTSv8ErhN8DfBQ6ddzgKeNMcY6tZ6wiETcwSPFdH5okV/blqmDSUzQ0gFuENKYuzEmEcgDzgeesdZ+GrBLM+AbAGvtMWPMAaAhsDeMtYqIS/z6xU9Z/tXJH+/Hru3EDT20OqabhBTu1trjQJoxpj4w1xjT0Vq7tswuwT6qT+m1G2PGAGMAUrVMqkjU2XngMD0fXezXpqUD3KlSs2WstfuNMUuAQUDZcC8AWgAFxphawNnA90G+PxvIBu+dmKpYs4g44JKp77Hr4FHf9sybetDnwnMdrEhOJ5TZMo2B4tJgPwPoj/eEaVnzgd8AK4CRwGKNt4vEhg07D3LVjGV+beqtu18oPfemwKzScfcEYLa19i1jzBQg11o7H3gJ+F9jzFd4e+yjIlaxiNSYwKUD3rqzNx2bne1QNVIZocyW+QI45SoEa+2DZb4+AlwX3tJExCkffbWXjBdPzps4+4wk1kwe6GBFUlm6h5VIrMvJAY8HEhK8jzk5p93dM+Ftv2Bfdu8VCvYopOUHRGJZTg6MGQNFpXc82rbNuw2n3FHq9VUF3D17jW+7h+cc/nnr
pTVVqYSZceq8Z3p6us3NzXXk2CJxw+PxBnqgli0hPx+AkhJL6/v9F/pa8+BAzk5Jinx9UmnGmDxrbXpF+6nnLhLLtm8/bfvTizfz+KJNvubr05vzp5FdaqIyiTCFu0gsS00N2nM/4mlNu4CZMFroK7bohKpILMvKgpQUv6Z7h95Nu+tn+LbvGXgB+dOGKNhjjMJdJFIqOUslIjIyIDsbWrZk/xn18Nz3FrMv6ut7+uupg7mjb9uar0siTsMyIpFQiVkqEZeRgefL+n5NT97QhV92bV6zdUiNUs9dJBIyM08G+wlFRd72GrR+x8FTrjLNnzZEwR4H1HMXiYQKZqnUhMBQnzaiE6Mu1mqs8ULhLhIJ5cxSoQaWul68cRf/NdP/GhIt9BV/FO4ikZCV5T/mDt5ZK1lZET1sYG/91ZsvoXfbRhE9priTxtzjhRtmbsSTMrNUMMb7mJ0dsZOpMz/aGnRsXcEev9RzjwdumrkRTzIyIv7+WmtpNdF/6YB3//sy2v6sXkSPK+6nnns8cMnMDQmvB+atPSXY86cNUbALoJ57fHDBzA0Jn2PHSzg/c4FfW+6k/jSqW8ehisSNFO7xwMGZGxJew5/5iNXf7PdtN6t/Bh9N6Hua75B4pXCPBw7N3JDw2V/0E2lT3vVr00JfcjoK93hw4qReZqZ3KCY11RvsOpkaFQJnwbRvehYLxv7CoWokWijc40UNzNyQ8PpqdyH9n/jQr+3rqYNJSDAOVSTRROEu4kKBvfVBFzXh+Ru7O1SNRCOFu4iLLN20h/98+TO/Ni0dIFWhcBdxicDe+j0DL9Ba61JlCncRh836OJ/J89f5tam3LtWlcBdxUGBv/flfd2NQx6YOVSOxRMsPSOxz4aJpE1//IuhCXwp2CRf13CW2uWzRtGALfb11Z286Nju7xmuR2GastY4cOD093ebm5la8o0h1eDzBl15o2RLy82u0lEFPLWXjd4f82jS2LpVljMmz1qZXtJ967hLbXLBo2tFjx7lw0kK/ts/u78e5ZyXXWA0SfxTuEtscXjQtcFwd1FuXmqETqhLbsrK8i6SVVQOLpu0tPHpKsG98ZJCCXWqMwl1iWw3f7g68vfX0P77n227V6Ezypw2p/gqOLpz1I+5V4bCMMaYF8ArQBCgBsq21MwL26QO8AWwtbXrdWjslvKWKVFENLZq2avsPjHj2Y7+2rY8OxpgwLPTlslk/4n6hjLkfA8ZZa1cZY+oBecaYd6216wP2W2atHRr+EkXcL3AI5pq085gxqmv4DnC6WyUq3CWICsPdWrsT2Fn69SFjzAagGRAY7iJx55+53zB+zhd+bREZV3fBrB+JLpWaLWOM8QBdgU+DPN3TGLMG2AHcY61dF2QfkZgR2Fu/uXcrHhjaITIH060SpZJCDndjTF3gNeAP1tqDAU+vAlpaawuNMYOBecApy9kZY8YAYwBS9Z9SotTkN9Yya4V/0EZ8FoxulSiVFNIVqsaYJOAt4F/W2idC2D8fSLfW7i1vH12hKtEosLf+xPVdGNGtec0cPCdHt0qU8F2haryn+l8CNpQX7MaYJsAua601xlyMd4rlvkrWLOJag2csY/1O/19Ya3zOum6VKJUQyrBML+BG4EtjzOrStvuBVABr7fPASOA2Y8wx4DAwyjq1aI1IGJWUWFrf77/Q17zbe5HWor5DFYmEJpTZMsuB007UtdY+DTwdrqJE3EBLB0g009oyIgF+PHqMiyb/y6/t0/v78TMt9CVRROEuUoZ66xIrFO4iwDffF/GLP33g17bxkUHVXw9GxCEKd4l76q1LLFK4S9xasWUf//HCJ35tYVvoS8RhCneJS4G99UvbNORvv/u5Q9WIhJ/CXeLKKyvyefAN/2WPNAQjsUjhLnEjsLd+Z9/zGTfwQoeqEYkshbvEvKfe28RT7232a1NvXWKdwl1iWmBv/ZlfdWNI56YOVSNScxTuEpN+OyuX9zbs8mtTb13iicJdYsrxEkubgIW+Fo+7
nNaN6zpUkYgzFO4SM7pOWcQPRcV+beqtS7xSuEvUKzx6jI4BC32teXAgZ6ckOVSRiPMU7hLVtHSASHAKd4lKBT8U0fsx/4W+NmddRVJigkMVibiLwl2iTmBv/WJPA2bf2tOhakTcSeEuUSNv2/dc+9wKvzYNwYgEp3CXqBDYW/9t71ZMGtrBoWpE3E/hLq72+qoC7p69xq9NvXWRiincxbUCe+t/GtmZ69NbOFSNSHRRuIvrPLpgA3/58Gu/NvXWRSpH4S6uEthbn31LTy5u1cChakSil8JdXOFXL3zCx1v2+bWpty5SdQp3cdSx4yWcn7nAr23ZvVfQokGKQxWJxAaFuzimbeY7FB+3fm3qrYuEh67Vlhp34HAxnglv+wX7l3PGkf+nq8HjgZwc54oTiRHquUuNCjxhWjfBsnbGDVBU5G3Ytg3GjPF+nZFRw9WJxA713KVGfHfgyCnBvmXqYNb+7Y6TwX5CURFkZtZgdSKxRz13ibjAUO9zYWNm3nSxd2P79uDfVF67iIRE4S4Rs27HAYb8eblf2yknTFNTvUMxgVJTI1iZSOxTuEtEBPbWH7u2Ezf0CBLYWVneMfayQzMpKd52EakyhbuE1fsbdnHzrFy/ttNObzxx0jQz0zsUk5rqDXadTBWplgrD3RjTAngFaAKUANnW2hkB+xhgBjAYKAJGW2tXhb9ccbPA3nrOby+h1/mNKv7GjAyFuUiYhdJzPwaMs9auMsbUA/KMMe9aa9eX2ecqoG3pn0uA50ofJQ789aOtPPzmer82XYwk4qwKw91auxPYWfr1IWPMBqAZUPan+RrgFWutBT4xxtQ3xjQt/V6JUdZaWk18x6/tvbsv4/xz6zlUkYicUKkxd2OMB+gKfBrwVDPgmzLbBaVtfuFujBkDjAFI1WyIqDZp3pe8+on/dEX11kXcI+RwN8bUBV4D/mCtPRj4dJBvsac0WJsNZAOkp6ef8ry4X7CFvnIn9adR3ToOVSQiwYQU7saYJLzBnmOtfT3ILgVA2VvkNAd2VL88cZNrn/uYvG0/+LZbNDiDZff2dbAiESlPKLNlDPASsMFa+0Q5u80H7jDG/APvidQDGm+PHYeOFNPpoUV+bRsfGURyUqJDFYlIRULpufcCbgS+NMasLm27H0gFsNY+D7yDdxrkV3inQt4U/lLFCYHL8l7VsQnP/bq7gxWJSChCmS2znOBj6mX3scDt4SpKnFfwQxG9H/vAr+3rqYNJSDjtfwURcQldoSqnCLwY6a5+bbl7wAUOVSMiVaFwF5813+znmmc+8mvT9EaR6KRwF+DU3vpTN6QxvGszh6oRkepSuMe5hWt3cuur/ssAqbcuEv0U7nEssLc++5aeXNyqgUPViEg4Kdzj0PMfbmHago1+beqti8QWhXscCbbQ1wf39KFVozMdqkhEIkXhHifGzV7Da6sK/NrUWxeJXQr3GPfTsRIumOS/0NfqBwdQP6W2QxWJSE1QuMewq2YsY8POkwt4tmtSj4V/uMzBikSkpijcY9CBomK6TPFf6OvffxxEnVpa6EskXijcY0zg9MZfdm3GkzekOVSNiDhF4R4jdh86wsVZ7/u1bX10MN4Vm0Uk3ijcY0C//1nClj0/+rbvHXQhv+9zvoMViYjTFO5R7KvdhfR/4kO/Nk1vFBFQuEetwLH11267lO4tz3GoGhFxmwSnC4gpOTng8UBCgvcxJyfsh1iZ/71fsBvj7a0r2EWkLIV7uOTkwJgxsG0bWOt9HDMmrAHvmfA21z2/wrf9wT192PqohmFcpQY+4EVCoXAPl8xMKCrybysq8rZX09tf7PTrrbdrUo/8aUO0Jozb1MAHvEiojPf2pzUvPT3d5ubmOnLsiEhI8P5ABzIGSkqq9JLBFvrKndSfRnXrVOn1JMI8Hm+gB2rZEvLza7oaiVHGmDxrbXpF+6nnHi6pqZVrr8CLy772C/YhnZqSP21I9AZ7PAxXbN9euXaRCNJsmXDJyvL+Cl52aCYl
xdteCcXHS2ib6b/Q1/opV5JSO4r/qU4MV5x4b04MVwBkZDhXV7ilpgbvuVfxA16kOtRzD5eMDMjO9v4Kboz3MTu7UuH10Px1fsH++z5tyJ82JLqDHSJ6PsJVsrK8H+hlVeEDXiQcNObuAoeOFNPpIf+FvrZMHUxiQowsHRCB8xGulZPj/dDavt3bY8/Kiq3fTsRxoY65R3mXMPr95uXP+HDTHt/21F924leXxNiv8fE0XJGRoTAXV1C4O+S7A0f4+aNxstBXmM5HiEjoFO4O6P3YYgp+OOzbfuk36fRr/zMHK4qwEz1ZDVeI1BidUK1Bm3YdwjPhbb9gz582JDzB7vaphhkZ3rneJSXeRwW7SESp515DAhf6euP2XnRpUT88Lx4vUw1FJGTquUfYx1v2+gX7mbUTyZ82JHzBDvEz1VBEQqaeewQF9taXjr+C1IYp5exdDboyUkQCqOceAW+s/tYv2Lu0qE/+tCGRCXYI+9IHIhL9Kuy5G2NeBoYCu621HYM83wd4A9ha2vS6tXZKOIuMFsEW+vr8gQGcc2btyB5YUw1FJEAoPfeZwKAK9llmrU0r/ROXwf7G6m/9gn1E12bkTxsS+WCHsCx9ICKxpcKeu7V2qTHGE/lSolOwhb7+/cdB1KmVWLOF6MpIESkjXCdUexpj1gA7gHustevC9Lqulr10C1Pf2ejbnj6yM9elt3CwIhERr3CE+yqgpbW20BgzGJgHtA22ozFmDDAGIDWKT/b9ePQYF03+l1/b11MHkxArC32JSNSr9mwZa+1Ba21h6dfvAEnGmEbl7JttrU231qY3bty4uod2xJy8Ar9g/+tNPcifNkTBLiKuUu2euzGmCbDLWmuNMRfj/cDYV+3KXObgkWI6l1mW94ykRDY8UtF5ZhERZ4QyFfLvQB+gkTGmAJgMJAFYa58HRgK3GWOOAYeBUdapReIjJHBsfck9ffDo5tQi4mKhzJb5jwqefxp4OmwVucjuQ0e4OOvksrw3927FA0M7OFiRiEhotPxAObLeXs8Ly7b6tj+7vx/nnpXsYEUiIqFTuAfYtu9HLp++xLd936B23NanjXMFiYhUgcK9jLH/+Jw3Vu/wba+ZPJCzz0hysCIRkapRuAPrdhxgyJ+X+7Y0i8MqAAAFhElEQVT/NLIz1+tiJBGJYnEd7tZaRmV/wqdbvwegXnItVmb2JzmphpcOEBEJs7gN90++3seo7E982y/8ZzoDOsTwfUxFJK7EXbgfO17CgCeXsnXvjwCcf25dFo79BbUStbS9iMSOuAr3hWu/49ZX83zbs2/pycWtGjhYkYhIZMRFd/VI8XE6PLjQF+y9zm/I1kcHVy3Yc3LA44GEBO9jTk5YaxURCYeY77n/38rt3Pfal77tBWN/QfumZ1XtxXJy/O94tG2bdxu0lrqIuIpxahmY9PR0m5ubG7HXP1BUTJcpJxf6GtGtGU9cn1a9F/V4vIEeqGVLyM+v3muLiITAGJNnrU2vaL+Y7Lk/88FXTP/Xv33by+69ghYNwnBz6u3bK9cuIuKQmAr3XQePcMnUkwt93Xp5GyZc1S58B0hNDd5zj+Ibj4hIbIqZcH9o/jpmfpzv216Z2Z/G9eqE9yBZWf5j7gApKd52EREXifpw37r3R654fIlve9KQ9vz2F60jc7ATJ00zM71DMamp3mDXyVQRcZnoCvecHF+w2tRU7rjlSd4+UNv39JcPDaRecoQX+srIUJiLiOtFT7iXmYb45c/acPWoGXDA+9QT13dhRLfmztYnIuIi0XMRU2YmFBXxzVnncvXoGQA0/HE/G2ePVbCLiASInp576XTDuj8dplf+am5eOY++X+eCMQ4XJiLiPtET7qXTEM85coic/5vk3y4iIn6iZ1gmK8s77bAsTUMUEQkqesI9IwOys72X+hvjfczO1swVEZEgomdYBjQNUUQkRNHTcxcRkZAp3EVEYpDCXUQkBincRURikMJdRCQGOXYnJmPMHiDI4uinaATs
jXA50UjvS/n03gSn96V80fTetLTWNq5oJ8fCPVTGmNxQbikVb/S+lE/vTXB6X8oXi++NhmVERGKQwl1EJAZFQ7hnO12AS+l9KZ/em+D0vpQv5t4b14+5i4hI5UVDz11ERCrJleFujGlhjPnAGLPBGLPOGDPW6ZrcxBiTaIz53BjzltO1uIkxpr4xZo4xZmPp/52eTtfkFsaY/y79WVprjPm7MSbZ6ZqcYox52Riz2xiztkxbA2PMu8aYzaWP5zhZYzi4MtyBY8A4a2174OfA7caYDg7X5CZjgQ1OF+FCM4CF1tp2QBf0HgFgjGkG3AWkW2s7AonAKGerctRMYFBA2wTgfWttW+D90u2o5spwt9butNauKv36EN4f0mbOVuUOxpjmwBDgRadrcRNjzFnAZcBLANban6y1+52tylVqAWcYY2oBKcAOh+txjLV2KfB9QPM1wKzSr2cBw2u0qAhwZbiXZYzxAF2BT52txDWeAu4FSpwuxGVaA3uAv5YOWb1ojDnT6aLcwFr7LfA4sB3YCRyw1i5ytirX+Zm1did4O5fAuQ7XU22uDndjTF3gNeAP1tqDTtfjNGPMUGC3tTbP6VpcqBbQDXjOWtsV+JEY+NU6HErHj68BWgHnAWcaY37tbFUSaa4Nd2NMEt5gz7HWvu50PS7RCxhmjMkH/gH0Nca86mxJrlEAFFhrT/yGNwdv2Av0B7Zaa/dYa4uB14FLHa7JbXYZY5oClD7udrieanNluBtjDN6x0w3W2iecrsctrLUTrbXNrbUevCfEFltr1QMDrLXfAd8YYy4sbeoHrHewJDfZDvzcGJNS+rPVD51sDjQf+E3p178B3nCwlrBw6z1UewE3Al8aY1aXtt1vrX3HwZrE/e4EcowxtYGvgZscrscVrLWfGmPmAKvwzkT7nBi8IjNUxpi/A32ARsaYAmAyMA2YbYy5Ge+H4XXOVRgeukJVRCQGuXJYRkREqkfhLiISgxTuIiIxSOEuIhKDFO4iIjFI4S4iEoMU7iIiMUjhLiISg/4fYveWZhphReIAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn, optim\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],\n",
    "                    [9.779], [6.182], [7.59], [2.167], [7.042],\n",
    "                    [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)\n",
    "\n",
    "y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],\n",
    "                    [3.366], [2.596], [2.53], [1.221], [2.827],\n",
    "                    [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)\n",
    "\n",
    "\n",
    "x_train = torch.from_numpy(x_train)\n",
    "\n",
    "y_train = torch.from_numpy(y_train)\n",
    "\n",
    "\n",
    "# Linear Regression Model\n",
    "class LinearRegression(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(LinearRegression, self).__init__()\n",
    "        self.linear = nn.Linear(1, 1)  # input and output is 1 dimension\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.linear(x)\n",
    "        return out\n",
    "\n",
    "\n",
    "model = LinearRegression()\n",
    "# Define the loss function and the optimizer\n",
    "criterion = nn.MSELoss()\n",
    "# optimizer = optim.SGD(model.parameters(), lr=1e-4)\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-4)\n",
    "\n",
    "# Start training\n",
    "num_epochs = 10000  # the number of epochs can be changed freely\n",
    "for epoch in range(num_epochs):\n",
    "    inputs = x_train\n",
    "    target = y_train\n",
    "\n",
    "    # forward\n",
    "    out = model(inputs)\n",
    "    loss = criterion(out, target)\n",
    "    # backward\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if (epoch+1) % 200 == 0:\n",
    "        print('Epoch[{}/{}], loss: {:.6f}'\n",
    "              .format(epoch+1, num_epochs, loss.item()))\n",
    "\n",
    "model.eval()\n",
    "predict = model(x_train)\n",
    "predict = predict.detach().numpy()\n",
    "plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original data')\n",
    "plt.plot(x_train.numpy(), predict, label='Fitting Line')\n",
    "# Show the legend\n",
    "plt.legend()\n",
    "plt.show()\n",
    "\n",
    "# Save the model\n",
    "torch.save(model.state_dict(), './linear.pth')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": true,
   "title_cell": "Table of Contents",
   "title_sidebar": "torch.optim",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
